// Scatter the "hot" entries of a one-hot encoding.
//
// Each thread handles one (n, s) pair, where n comes from the x dimension of
// the launch and s from the y dimension. It reads a class index from
//   input[s + spatial_dim * n]
// and stores 1 at
//   output[s + spatial_dim * (class + channels * n)].
//
// Only the "hot" entries are written; every other element of `output` is left
// untouched by this function.
// NOTE(review): presumably the caller zero-fills `output` before launching --
// confirm against the host code.
//
// Parameters:
//   input       - class indices stored as T, indexed as [n][s]
//   output      - destination buffer, indexed as [n][class][s]
//   num         - extent of the n index
//   channels    - number of classes (extent of the class index)
//   spatial_dim - extent of the s index
//
// Class indices outside [0, channels) are skipped instead of being written:
// they would otherwise land in another sample's slab or outside the buffer.
template <typename T>
__device__ void index2onehot_forward(const T *input, T *output, int num, int channels, int spatial_dim) {
  const int n = threadIdx.x + blockIdx.x * blockDim.x;
  const int s = threadIdx.y + blockIdx.y * blockDim.y;
  // The grid may be larger than the data; surplus threads have nothing to do.
  if (n >= num || s >= spatial_dim)
    return;

  // Flat offsets are computed in size_t so that large tensors do not overflow
  // 32-bit int arithmetic. n and s are non-negative past the guard above.
  const size_t sample = static_cast<size_t>(n);
  const size_t pos = static_cast<size_t>(s);
  const size_t plane = static_cast<size_t>(spatial_dim);

  const size_t idx_src = pos + plane * sample;
  const int i_class = static_cast<int>(input[idx_src]);

  // Reject labels that do not name a valid channel (e.g. negative values).
  if (i_class < 0 || i_class >= channels)
    return;

  const size_t idx_dst =
      pos + plane * (static_cast<size_t>(i_class) + static_cast<size_t>(channels) * sample);
  output[idx_dst] = static_cast<T>(1);
}

// Kernel entry points with C linkage (unmangled symbol names), one per
// element type. Each simply forwards its arguments to the templated
// index2onehot_forward device routine.
extern "C" {
  // Single-precision entry point.
  __global__ void index2onehot_forward_float(
      float *input, float *output, int num, int channels, int spatial_dim) {
    index2onehot_forward<float>(input, output, num, channels, spatial_dim);
  }

  // Double-precision entry point.
  __global__ void index2onehot_forward_double(
      double *input, double *output, int num, int channels, int spatial_dim) {
    index2onehot_forward<double>(input, output, num, channels, spatial_dim);
  }
}

// vim: ft=cuda
